#  Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at:
#
#       https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
#  or implied. See the License for the specific language governing
#  permissions and limitations under the License.
"""Utilities for run templates."""

from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

from pydantic import create_model
from pydantic.fields import FieldInfo

from zenml.config import ResourceSettings
from zenml.config.base_settings import BaseSettings
from zenml.config.pipeline_run_configuration import PipelineRunConfiguration
from zenml.config.source import SourceWithValidator
from zenml.config.step_configurations import StepConfigurationUpdate
from zenml.enums import StackComponentType
from zenml.logger import get_logger
from zenml.stack import Flavor
from zenml.zen_stores.schemas import (
    PipelineSnapshotSchema,
)

if TYPE_CHECKING:
    from zenml.config.pipeline_configurations import PipelineConfiguration
    from zenml.config.step_configurations import Step

logger = get_logger(__name__)


def validate_snapshot_is_templatable(
    snapshot: PipelineSnapshotSchema,
) -> None:
    """Ensure a snapshot can serve as the basis of a run template.

    A snapshot qualifies only if it references a build, that build
    references a stack, and every component of that stack has a flavor
    which is not custom and whose instantiated configuration is not local.

    Args:
        snapshot: The snapshot to check.

    Raises:
        ValueError: If any of the conditions above is violated.
    """
    build = snapshot.build
    if not build:
        raise ValueError(
            "Unable to create run template as there is no associated build. "
            "Run templates can only be created for remote orchestrators that "
            "use container images to run the pipeline."
        )

    stack = build.stack
    if not stack:
        raise ValueError(
            "Unable to create run template as the associated build has no "
            "stack reference."
        )

    for stack_component in stack.components:
        flavor_schema = stack_component.flavor_schema
        if not flavor_schema:
            raise ValueError(
                "Unable to create run template as a component of the "
                "associated stack has no flavor."
            )
        if flavor_schema.is_custom:
            raise ValueError(
                "Unable to create run template as a component of the "
                "associated stack has a custom flavor."
            )

        # Instantiate the flavor's config class from the stored component
        # configuration so we can ask it whether the component is local.
        config_cls = Flavor.from_model(flavor_schema.to_model()).config_class
        stored_configuration = stack_component.to_model(
            include_metadata=True
        ).configuration
        if config_cls(**stored_configuration).is_local:
            raise ValueError(
                "Unable to create run template as the associated stack "
                "contains local components."
            )


def generate_config_template(
    snapshot: PipelineSnapshotSchema,
    pipeline_configuration: "PipelineConfiguration",
    step_configurations: Dict[str, "Step"],
) -> Dict[str, Any]:
    """Build the default run configuration for a snapshot's run template.

    Per-step overrides and the pipeline configuration are dumped, restricted
    to the fields a run configuration accepts. `None` and default values are
    omitted, and any `docker` settings are removed.

    Args:
        snapshot: The snapshot the template is generated for.
        pipeline_configuration: The pipeline configuration.
        step_configurations: The step configurations, keyed by step name.

    Returns:
        A dictionary containing the snapshot's run name template under
        `run_name`, the per-step configurations under `steps`, and the
        remaining pipeline-level options.
    """
    allowed_step_keys = set(StepConfigurationUpdate.model_fields)

    steps_configs: Dict[str, Dict[str, Any]] = {}
    for step_name, step in step_configurations.items():
        step_config = step.step_config_overrides.model_dump(
            include=allowed_step_keys,
            exclude={"name", "outputs"},
            exclude_none=True,
            exclude_defaults=True,
        )
        # NOTE(review): Docker settings are dropped, presumably because the
        # template reuses the image(s) of the snapshot's existing build.
        step_config.get("settings", {}).pop("docker", None)
        steps_configs[step_name] = step_config

    # Pipeline parameters are only part of the template for dynamic
    # snapshots.
    excluded_pipeline_keys = (
        {"schedule", "build"}
        if snapshot.is_dynamic
        else {"schedule", "build", "parameters"}
    )
    pipeline_config = pipeline_configuration.model_dump(
        include=set(PipelineRunConfiguration.model_fields),
        exclude=excluded_pipeline_keys,
        exclude_none=True,
        exclude_defaults=True,
    )
    pipeline_config.get("settings", {}).pop("docker", None)

    return {
        "run_name": snapshot.run_name_template,
        "steps": steps_configs,
        **pipeline_config,
    }


def generate_config_schema(
    snapshot: PipelineSnapshotSchema,
    pipeline_configuration: "PipelineConfiguration",
    step_configurations: Dict[str, "Step"],
) -> Dict[str, Any]:
    """Generate a run configuration JSON schema for the snapshot.

    The schema is derived from `PipelineRunConfiguration` and
    `StepConfigurationUpdate`, specialized for this snapshot:

    - `settings` offers resource settings plus the settings of each stack
      component flavor whose settings class defines any fields.
    - `experiment_tracker` / `step_operator` accept the name of a matching
      stack component (as an enum) or a boolean. They are only present if
      the stack contains such a component.
    - Every step with parameters gets a required `parameters` object listing
      those parameters, which also makes that step and `steps` required.
    - For dynamic snapshots, pipeline-level `parameters` and the step
      `runtime` option are included.

    Args:
        snapshot: The snapshot schema.
        pipeline_configuration: The pipeline configuration.
        step_configurations: The step configurations.

    Returns:
        The generated JSON schema dictionary.
    """
    # Config schema can only be generated for a runnable snapshot, so this is
    # guaranteed by checks in the snapshot schema
    assert snapshot.build
    assert snapshot.build.stack

    stack = snapshot.build.stack
    experiment_trackers = []
    step_operators = []

    settings_fields: Dict[str, Any] = {
        "resources": (Optional[ResourceSettings], None)
    }
    for component in stack.components:
        if not component.flavor_schema:
            continue

        flavor_model = component.flavor_schema.to_model()
        flavor = Flavor.from_model(flavor_model)

        for class_ in flavor.config_class.__mro__[1:]:
            # Ugly hack to get the settings class of a flavor without having
            # the integration installed. This is based on the convention that
            # the static config of a stack component should always inherit
            # from the dynamic settings.
            if issubclass(class_, BaseSettings):
                if len(class_.model_fields) > 0:
                    settings_key = f"{component.type}.{component.flavor}"
                    settings_fields[settings_key] = (
                        Optional[class_],
                        None,
                    )

                break

        if component.type == StackComponentType.EXPERIMENT_TRACKER:
            experiment_trackers.append(component.name)
        if component.type == StackComponentType.STEP_OPERATOR:
            step_operators.append(component.name)

    settings_model = create_model("Settings", **settings_fields)

    # Step configuration fields that are either left out of the schema or
    # re-added below with snapshot-specific types. This does not depend on
    # the field being inspected, so it is computed once.
    step_config_exclude = {
        "name",
        "outputs",
        "step_operator",
        "experiment_tracker",
        "parameters",
    }
    if not snapshot.is_dynamic:
        step_config_exclude.add("runtime")

    generic_step_fields: Dict[str, Any] = {}
    for key, field_info in StepConfigurationUpdate.model_fields.items():
        if key in step_config_exclude:
            continue

        # Source objects are exposed as plain strings in the schema.
        if field_info.annotation == Optional[SourceWithValidator]:  # type: ignore[comparison-overlap]
            generic_step_fields[key] = (Optional[str], field_info)
        else:
            generic_step_fields[key] = (field_info.annotation, field_info)

    if experiment_trackers:
        experiment_tracker_enum = Enum(  # type: ignore[misc]
            "ExperimentTrackers", {e: e for e in experiment_trackers}
        )
        generic_step_fields["experiment_tracker"] = (
            Optional[Union[experiment_tracker_enum, bool]],
            None,
        )
    if step_operators:
        step_operator_enum = Enum(  # type: ignore[misc]
            "StepOperators", {s: s for s in step_operators}
        )
        generic_step_fields["step_operator"] = (
            Optional[Union[step_operator_enum, bool]],
            None,
        )

    generic_step_fields["settings"] = (Optional[settings_model], None)

    all_steps: Dict[str, Any] = {}
    all_steps_required = False
    for step_name, step in step_configurations.items():
        step_fields: Dict[str, Any] = {}

        if step.config.parameters:
            parameter_fields: Dict[str, Any] = {}

            for parameter_name in step.config.parameters:
                field_name = _sanitize_field_name(
                    parameter_name, parameter_fields
                )
                # The alias keeps the original parameter name in the schema.
                parameter_fields[field_name] = (
                    Any,
                    FieldInfo(default=..., validation_alias=parameter_name),
                )

            step_parameters_class = create_model(
                f"{step_name}_parameters", **parameter_fields
            )
            step_fields["parameters"] = (
                step_parameters_class,
                FieldInfo(default=...),
            )

        step_fields.update(generic_step_fields)
        step_model = create_model(step_name, **step_fields)

        sanitized_step_name = _sanitize_field_name(step_name, all_steps)

        if step.config.parameters:
            # This step has required parameters -> we make this attribute
            # required and also the parent attribute so these parameters must
            # always be included
            all_steps_required = True
            all_steps[sanitized_step_name] = (
                step_model,
                FieldInfo(default=..., validation_alias=step_name),
            )
        else:
            all_steps[sanitized_step_name] = (
                Optional[step_model],
                FieldInfo(default=None, validation_alias=step_name),
            )

    all_steps_model = create_model("Steps", **all_steps)

    top_level_fields: Dict[str, Any] = {}

    if snapshot.is_dynamic:
        pipeline_parameter_fields: Dict[str, Any] = {}

        for parameter_name in pipeline_configuration.parameters or {}:
            field_name = _sanitize_field_name(
                parameter_name, pipeline_parameter_fields
            )
            pipeline_parameter_fields[field_name] = (
                Any,
                FieldInfo(default=..., validation_alias=parameter_name),
            )

        pipeline_parameters_class = create_model(
            "Parameters", **pipeline_parameter_fields
        )
        # The field defaults to `None`, so its type must be nullable as well;
        # otherwise the generated schema advertises a default that does not
        # satisfy the field's own type.
        top_level_fields["parameters"] = (
            Optional[pipeline_parameters_class],
            FieldInfo(default=None),
        )

    for key, field_info in PipelineRunConfiguration.model_fields.items():
        if key in ["schedule", "build", "steps", "settings", "parameters"]:
            continue

        # Source objects are exposed as plain strings in the schema.
        if field_info.annotation == Optional[SourceWithValidator]:  # type: ignore[comparison-overlap]
            top_level_fields[key] = (Optional[str], field_info)
        else:
            top_level_fields[key] = (field_info.annotation, field_info)

    top_level_fields["settings"] = (Optional[settings_model], None)

    if all_steps_required:
        top_level_fields["steps"] = (all_steps_model, FieldInfo(default=...))
    else:
        top_level_fields["steps"] = (
            Optional[all_steps_model],
            FieldInfo(default=None),
        )

    return create_model("Result", **top_level_fields).model_json_schema()  # type: ignore[no-any-return]


def _sanitize_field_name(name: str, existing_names: Dict[str, Any]) -> str:
    """Turn `name` into a pydantic field name that is not yet taken.

    Pydantic doesn't allow field names to start with an underscore, so
    leading underscores are stripped. Underscores are then appended until
    the result does not collide with a key of `existing_names`.

    Args:
        name: The original (possibly underscore-prefixed) name.
        existing_names: The fields already defined for the model being built.

    Returns:
        A field name without leading underscores that is not a key of
        `existing_names`.
    """
    sanitized_name = name.lstrip("_")
    while sanitized_name in existing_names:
        sanitized_name += "_"
    return sanitized_name
